{ "cells": [ { "cell_type": "markdown", "id": "32szzPY4RyWO", "metadata": { "id": "32szzPY4RyWO" }, "source": [ "### **4. R-learner**\n", "The classical R-learner builds on the residualization idea of Robinson (1988) [3] and was formalized by Nie and Wager (2021) [2]. It starts from the partially linear model setup, in which we assume that\n", "\\begin{equation}\n", " \\begin{aligned}\n", " R&=A\\tau(S)+g_0(S)+U,\\\\\n", " A&=m_0(S)+V,\n", " \\end{aligned}\n", "\\end{equation}\n", "where $R$ is the outcome, $A$ is the treatment, $S$ denotes the covariates, and the noise terms satisfy $\\mathbb{E}[U|A,S]=0$ and $\\mathbb{E}[V|S]=0$.\n", "\n", "Take the conditional expectation of the first equation given $S$. Since $\\mathbb{E}[A|S]=m_0(S)$, and $\\mathbb{E}[U|S]=0$ by the law of iterated expectations, this gives $\\mathbb{E}[R|S]=\\tau(S)m_0(S)+g_0(S)$. Subtracting this from the first equation removes the nuisance function $g_0$:\n", "\\begin{equation}\n", "\tR-\\mathbb{E}[R|S]=\\tau(S)\\cdot(A-\\mathbb{E}[A|S])+\\epsilon,\n", "\\end{equation}\n", "where $\\epsilon=U$. Write $m_0(S)=\\mathbb{E}[A|S]$ and $\\eta_0(S)=\\mathbb{E}[R|S]$. This identity suggests a natural two-step estimator of $\\tau(S)$, which is the core idea of the R-learner:\n", "\n", "**Step 1**: Regress $R$ on $S$ to obtain $\\hat{\\eta}(S)=\\hat{\\mathbb{E}}[R|S]$, and regress $A$ on $S$ to obtain $\\hat{m}(S)=\\hat{\\mathbb{E}}[A|S]$. In practice, these nuisance models are usually cross-fitted, so that the predictions for observation $i$ come from models that were not trained on it.\n", "\n", "**Step 2**: Regress the outcome residual $R-\\hat{\\eta}(S)$ on the treatment (propensity score) residual $A-\\hat{m}(S)$. The coefficient $\\tau(S)$ is allowed to vary with $S$.\n", "\n", "That is,\n", "\\begin{equation}\n", "\t\\hat{\\tau}(S)=\\arg\\min_{\\tau}\\left\\{\\mathbb{E}_n\\left[\\left(\\{R_i-\\hat{\\eta}(S_i)\\}-\\{A_i-\\hat{m}(S_i)\\}\\cdot\\tau(S_i)\\right)^2\\right]\\right\\}\t\n", "\\end{equation}\n", "\n", "The simplest choice is to restrict $\\tau(\\cdot)$ to the linear function class. 
In this case $\\tau(S)=S\\beta$ (with $S$ viewed as a row vector), and estimating $\\tau$ reduces to the following least-squares problem in $\\beta$:\n", "\\begin{equation}\n", "\t\\hat{\\beta}=\\arg\\min_{\\beta}\\left\\{\\mathbb{E}_n\\left[\\left(\\{R_i-\\hat{\\eta}(S_i)\\}-\\{A_i-\\hat{m}(S_i)\\} S_i\\cdot \\beta\\right)^2\\right]\\right\\}.\n", "\\end{equation}\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "eRpP5k9MBtzO", "metadata": { "id": "eRpP5k9MBtzO" }, "outputs": [], "source": [ "# Import the required packages\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.linear_model import LinearRegression\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL\n", "from causaldm.learners.CEL.Single_Stage.Rlearner import Rlearner\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "id": "XUu695Qrf61-", "metadata": { "id": "XUu695Qrf61-" }, "source": [ "### MovieLens Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "JhfJntzcVVy2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "executionInfo": { "elapsed": 288, "status": "ok", "timestamp": 1676750101543, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "JhfJntzcVVy2", "outputId": "7fab8a7a-7cd9-445c-a005-9a6d1994a071" }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>user_id</th><th>movie_id</th><th>rating</th><th>age</th><th>Drama</th><th>Sci-Fi</th><th>gender_M</th><th>occupation_academic/educator</th><th>occupation_college/grad student</th><th>occupation_executive/managerial</th><th>occupation_other</th><th>occupation_technician/engineer</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>0</th><td>48.0</td><td>1193.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n", "    <tr><th>1</th><td>48.0</td><td>919.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n", "    <tr><th>2</th><td>48.0</td><td>527.0</td><td>5.0</td><td>25.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n", "    <tr><th>3</th><td>48.0</td><td>1721.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n", "    <tr><th>4</th><td>48.0</td><td>150.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n", "    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n", "    <tr><th>65637</th><td>5878.0</td><td>3300.0</td><td>2.0</td><td>25.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n", "    <tr><th>65638</th><td>5878.0</td><td>1391.0</td><td>1.0</td><td>25.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n", "    <tr><th>65639</th><td>5878.0</td><td>185.0</td><td>4.0</td><td>25.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n", "    <tr><th>65640</th><td>5878.0</td><td>2232.0</td><td>1.0</td><td>25.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n", "    <tr><th>65641</th><td>5878.0</td><td>426.0</td><td>3.0</td><td>25.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n", "  </tbody>\n", "</table>\n", "<p>65642 rows × 12 columns</p>\n", "</div>
" ], "text/plain": [ " user_id movie_id rating age Drama Sci-Fi gender_M \\\n", "0 48.0 1193.0 4.0 25.0 1.0 0.0 1.0 \n", "1 48.0 919.0 4.0 25.0 1.0 0.0 1.0 \n", "2 48.0 527.0 5.0 25.0 1.0 0.0 1.0 \n", "3 48.0 1721.0 4.0 25.0 1.0 0.0 1.0 \n", "4 48.0 150.0 4.0 25.0 1.0 0.0 1.0 \n", "... ... ... ... ... ... ... ... \n", "65637 5878.0 3300.0 2.0 25.0 0.0 1.0 0.0 \n", "65638 5878.0 1391.0 1.0 25.0 0.0 1.0 0.0 \n", "65639 5878.0 185.0 4.0 25.0 0.0 1.0 0.0 \n", "65640 5878.0 2232.0 1.0 25.0 0.0 1.0 0.0 \n", "65641 5878.0 426.0 3.0 25.0 0.0 1.0 0.0 \n", "\n", " occupation_academic/educator occupation_college/grad student \\\n", "0 0.0 1.0 \n", "1 0.0 1.0 \n", "2 0.0 1.0 \n", "3 0.0 1.0 \n", "4 0.0 1.0 \n", "... ... ... \n", "65637 0.0 0.0 \n", "65638 0.0 0.0 \n", "65639 0.0 0.0 \n", "65640 0.0 0.0 \n", "65641 0.0 0.0 \n", "\n", " occupation_executive/managerial occupation_other \\\n", "0 0.0 0.0 \n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 \n", "... ... ... \n", "65637 0.0 1.0 \n", "65638 0.0 1.0 \n", "65639 0.0 1.0 \n", "65640 0.0 1.0 \n", "65641 0.0 1.0 \n", "\n", " occupation_technician/engineer \n", "0 0.0 \n", "1 0.0 \n", "2 0.0 \n", "3 0.0 \n", "4 0.0 \n", "... ... 
\n", "65637 0.0 \n", "65638 0.0 \n", "65639 0.0 \n", "65640 0.0 \n", "65641 0.0 \n", "\n", "[65642 rows x 12 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the MovieLens data\n", "MovieLens_CEL = _env_getdata_CEL.get_movielens_CEL()\n", "# Drop the first (unused) column\n", "MovieLens_CEL.pop(MovieLens_CEL.columns[0])\n", "# Keep only the Drama and Sci-Fi genre indicators\n", "MovieLens_CEL = MovieLens_CEL[MovieLens_CEL.columns.drop(['Comedy', 'Action', 'Thriller'])]\n", "MovieLens_CEL" ] }, { "cell_type": "code", "execution_count": 24, "id": "J__3Ozs7Uxxs", "metadata": { "id": "J__3Ozs7Uxxs" }, "outputs": [], "source": [ "n = len(MovieLens_CEL)" ] }, { "cell_type": "code", "execution_count": 3, "id": "04f98da4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['age', 'gender_M', 'occupation_academic/educator',\n", " 'occupation_college/grad student', 'occupation_executive/managerial',\n", " 'occupation_other', 'occupation_technician/engineer'],\n", " dtype='object')" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Column positions of the user-level covariates\n", "userinfo_index = np.array([3, 6, 7, 8, 9, 10, 11])\n", "MovieLens_CEL.columns[userinfo_index]" ] }, { "cell_type": "code", "execution_count": 21, "id": "f8abaf99", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "estimate with R-learner\n", "fold 1,testing r2 y_learner: 0.019, ps_learner: 0.734\n", "fold 2,testing r2 y_learner: 0.015, ps_learner: 0.739\n", "fold 3,testing r2 y_learner: 0.017, ps_learner: 0.740\n", "fold 4,testing r2 y_learner: 0.017, ps_learner: 0.736\n", "fold 5,testing r2 y_learner: 0.018, ps_learner: 0.725\n", "fold 1, training r2 R-learner: 0.028, testing r2 R-learner: 0.028\n", "fold 2, training r2 R-learner: 0.031, testing r2 R-learner: 0.020\n", "fold 3, training r2 R-learner: 0.029, testing r2 R-learner: 0.029\n", "fold 4, training r2 R-learner: 0.030, testing r2 R-learner: 0.024\n", "fold 5, training r2 R-learner: 0.030, testing r2 R-learner: 0.024\n" ] } ], "source": [ 
"# R-learner for HTE estimation\n", "np.random.seed(1)\n", "outcome = 'rating'\n", "treatment = 'Drama'\n", "controls = ['age', 'gender_M', 'occupation_academic/educator',\n", " 'occupation_college/grad student', 'occupation_executive/managerial',\n", " 'occupation_other', 'occupation_technician/engineer']\n", "n_folds = 5  # number of folds used by Rlearner\n", "y_model = GradientBoostingRegressor(max_depth=2)  # outcome model for E[R|S]\n", "ps_model = LogisticRegression()  # propensity model for E[A|S]\n", "Rlearner_model = GradientBoostingRegressor(max_depth=2)  # final-stage model for tau(S)\n", "\n", "HTE_R_learner = Rlearner(MovieLens_CEL, outcome, treatment, controls, n_folds, y_model, ps_model, Rlearner_model)\n", "HTE_R_learner = HTE_R_learner.to_numpy()" ] }, { "cell_type": "markdown", "id": "FA-F8Jc_T5Lz", "metadata": { "id": "FA-F8Jc_T5Lz" }, "source": [ "Let's look at the estimated HTEs for three arbitrarily selected records (rows 0, 1000, and 5000):" ] }, { "cell_type": "code", "execution_count": 22, "id": "GvHnTOxmT5Lz", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 318, "status": "ok", "timestamp": 1676750150517, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "GvHnTOxmT5Lz", "outputId": "7b0b76fd-f5ac-4ab8-a3c0-188e15484fe7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R-learner: [0.05127254 0.08881288 0.10304225]\n" ] } ], "source": [ "print(\"R-learner: \",HTE_R_learner[np.array([0,1000,5000])])" ] }, { "cell_type": "code", "execution_count": 23, "id": "48136320", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "On average, choosing Drama instead of Sci-Fi is expected to increase the rating by 0.0755 out of 5 points.\n" ] } ], "source": [ "# Average the estimated HTEs over all records to obtain the ATE\n", "ATE_R_learner = np.mean(HTE_R_learner)\n", "print(\"On average, choosing Drama instead of Sci-Fi is expected to increase the rating by\", round(ATE_R_learner, 4), \"out of 5 points.\")" ] }, { "cell_type": "markdown", "id": "mVAZTZYTUKJ6", "metadata": { "id": "mVAZTZYTUKJ6" }, "source": [ "**Conclusion:** 
Averaged over all records, choosing Drama instead of Sci-Fi is estimated to increase the rating by about 0.0755 points on the 5-point scale." ] }, { "cell_type": "markdown", "id": "1098b550", "metadata": { "id": "1098b550" }, "source": [ "## References\n", "\n", "2. Xinkun Nie and Stefan Wager. Quasi-oracle estimation of heterogeneous treatment effects. Biometrika, 108(2):299–319, 2021.\n", "\n", "3. Peter M. Robinson. Root-N-consistent semiparametric regression. Econometrica, 56(4):931–954, 1988.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1bb391fb", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [ "1098b550" ], "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }